import sys
import signal
import xml.etree.ElementTree as ElementTree
import zipfile
import os
import re
import random
import json
import gzip
import datetime
import argparse
import code

import nltk
import numpy as np
import pandas as pd

from config import CLASS_VERSION,EPO_BASE

# Root directory of the EPO data; read_EPO_zip joins archive paths onto it.
base = EPO_BASE

# Directory holding the CSV outputs for the configured classification version.
PATH = "datasets/V{}".format(CLASS_VERSION)

#define a function to get the text of a patent description
def element_to_text(element):
    """Return the concatenated text of ``element`` and all of its descendants.

    The text fragments are taken in document order from ``element.itertext()``
    (so tail text of child elements is included) and joined without any
    separator. An element without text yields the empty string.
    """
    # str.join is linear; repeated "+=" re-copies the growing string each time.
    return "".join(element.itertext())

#opens a zip file and reads the patent description using the function element_to_text
def read_EPO_zip(zpf):
    """Read one EPO patent zip archive and return its metadata as a dict.

    ``zpf`` is joined onto the module-level ``base`` directory. The archive
    must contain an XML member named like the archive itself (``.zip``
    replaced by ``.xml``) whose root element is ``ep-patent-document``.

    The returned dict holds the root element's attributes (minus ``file`` and
    ``dtd-version``) plus:

    * ``epoid``     -- archive base name without ``.zip``
    * ``appln_nr``  -- ``epoid`` without its last four characters (``NW<kind>``)
    * ``pub_nr``    -- ``"EP" + doc-number + kind``
    * ``description`` -- full description text; only set when the document
      has a ``description`` element

    Raises:
        Exception: the root element is not ``ep-patent-document``.
        KeyError: an expected root attribute is missing.
        AssertionError: the kind is not A1/B1 or the file name does not end
            in ``NW<kind>``.
    """
    xml_name = os.path.basename(zpf).replace(".zip",".xml")
    with zipfile.ZipFile(os.path.join(base,zpf),'r') as zf:
        # Close the member handle as well as the archive.
        with zf.open(xml_name) as f:
            root = ElementTree.parse(f).getroot()

    if root.tag !=  "ep-patent-document":
        raise Exception("Unknown patent document type: " + root.tag)

    patent = dict(root.attrib)
    del patent["file"]
    del patent["dtd-version"]

    # The following two statements should hold for all A1 and B1 kinds
    # (other kinds have different prefixes than NW)
    assert(patent["kind"] in ["A1","B1"])
    patent["epoid"] = os.path.basename(zpf).replace(".zip","")
    assert(patent["epoid"][-4:] == "NW{}".format(patent["kind"]))

    patent["appln_nr"] = patent["epoid"][:-4]
    patent["pub_nr"] = "EP{}{}".format(patent["doc-number"],patent["kind"])

    description = root.find("description")
    # An Element is falsy when it has no child elements, even if it carries
    # text, so test against None instead of relying on truthiness.
    if description is not None:
        patent["description"] = element_to_text(description)

    return patent

# Set of all kind dirs:
# EPDLA9,EPNWA1
# EPNWA2,EPNWA3
# EPNWA8,EPNWA9
# EPNWB1,EPNWB2
# EPNWB3,EPW1A8
# EPW1A9,EPW1B8
# EPW1B9,EPW2A8
# EPW2A9,EPW2B8
# EPW2B9,EPW3A8
# EPW3A9,EPW3B8
# EPW3B9,EPW4B9

#Navigating the database:
#EPODatabase walks the year/week/kind directory layout on disk and iterates over all patents in it
class EPODatabase(object):
    """Navigate the on-disk directory layout of the EPO data.

    The path-building code in this class assumes the layout
    ``<root_dir>/<year>/EPRTBJV<year>0000<week>001001/DOC/EPNW<kind>/<patid>.zip``.
    Only the kinds listed in ``KINDS`` are visited by ``iterator``.
    """
    KINDS = ["A1","B1"]

    def __init__(self,root_dir):
        self.root_dir = root_dir

    def _week_dir(self,year,week):
        """Return the name of the directory for ``week`` of ``year`` (str or int)."""
        return "EPRTBJV{}0000{}001001".format(str(year),str(week).zfill(2))

    def get_weeks(self,year):
        """Return the sorted two-character week codes found below ``year``.

        Every entry of the year directory must be a 23-character week
        directory name; anything else trips the assertion.
        """
        weeks = []
        for week_dir in os.listdir(os.path.join(self.root_dir,str(year))):
            assert(len(week_dir) == 23)
            # Characters 15-16 are the zero-padded week, see _week_dir.
            weeks.append(week_dir[15:17])
        weeks.sort()
        return weeks

    def _mk_kind_dir(self,kind):
        """Return the directory name used for documents of ``kind``."""
        return "EPNW{}".format(kind)

    def kind_exists(self,year,week,kind):
        """Return True if the directory for (year, week, kind) exists."""
        d = os.path.join(self.root_dir,str(year),self._week_dir(year,week),"DOC",self._mk_kind_dir(kind))
        return os.path.isdir(d)

    def get_patents(self,year,week,kind):
        """Return the patent ids (zip names without ``.zip``) of one kind directory.

        Entries starting with "." are skipped. Any other entry that is not a
        ``.zip`` file is printed and then fails the assertion.
        """
        year = str(year)
        kind_dir = os.path.join(self.root_dir,year,self._week_dir(year,week),"DOC",self._mk_kind_dir(kind))

        patents = []
        for zipname in os.listdir(kind_dir):
            if zipname.startswith("."):
                continue
            if not zipname.endswith(".zip"):
                print((year,week,kind,zipname))
            assert(zipname.endswith(".zip"))
            patents.append(zipname[:-4])
        return patents

    def iterator(self,filters=None):
        """Yield ``(year, week, kind, patid)`` for every patent in the database.

        ``filters`` may map ``"year"``, ``"week"`` and/or ``"kind"`` to a
        container (or string) of accepted values; membership is tested with
        ``in``. Years and weeks are visited in sorted order, kinds in the
        order of ``KINDS``. The list of (year, week, kind) directories is
        collected completely before the first patent is yielded.
        """
        # None instead of a shared mutable default argument.
        if filters is None:
            filters = {}

        years = sorted(y for y in os.listdir(self.root_dir) if y.isdigit())

        collections = []
        for year in years:
            # Check the year filter first so that excluded years are not listed at all.
            if "year" in filters and year not in filters["year"]:
                continue
            for week in self.get_weeks(year):
                if "week" in filters and week not in filters["week"]:
                    continue
                for kind in EPODatabase.KINDS:
                    if "kind" in filters and kind not in filters["kind"]:
                        continue
                    if not self.kind_exists(year,week,kind):
                        continue
                    collections.append((year,week,kind))

        for (year,week,kind) in collections:
            for pat in self.get_patents(year,week,kind):
                yield (year,week,kind,pat)

    def read_patent(self,year,week,patid):
        """Read the patent ``patid`` published in ``week`` of ``year``.

        The kind is taken from the last two characters of ``patid``. Returns
        the dict produced by ``read_EPO_zip`` with ``year``, ``week`` and
        ``patid`` added.
        """
        kind = patid[-2:]

        # NOTE(review): read_EPO_zip joins this path onto the module-level
        # `base` again; that is a no-op only when root_dir is absolute -- confirm.
        zip_path = os.path.join(self.root_dir,str(year),self._week_dir(year,week),"DOC",self._mk_kind_dir(kind),"{}.zip".format(patid))

        p = read_EPO_zip(zip_path)
        p["year"] = year
        p["week"] = week
        p["patid"] = patid

        return p

    def save_index(self,file_name):
        """Write ``self.index`` as gzip-compressed JSON to ``file_name``.

        ``self.index`` is not set by ``__init__``; it must have been assigned
        (e.g. by ``load_index``) before calling this.
        """
        # Text mode: json.dumps returns str, which a binary GzipFile rejects.
        with gzip.open(file_name,"wt") as f:
            f.write(json.dumps(self.index))

    def load_index(self,file_name):
        """Load ``self.index`` from the gzip-compressed JSON file ``file_name``."""
        with gzip.open(file_name,"rt") as f:
            self.index = json.loads(f.read())


class Location(object):
    """Encode and decode ``loc`` strings of the form ``<year><week><patid>``.

    The year occupies the first four characters, the zero-padded week the
    next two, and the remainder is the patent id, which must start with "EP".
    """

    @staticmethod
    def _split(s):
        """Slice ``s`` into its (year, week, patid) parts without validating."""
        return (s[:4],s[4:6],s[6:])

    @staticmethod
    def _valid(year,week,patid):
        """Return True if the three parts look like a location."""
        return year.isdigit() and week.isdigit() and patid.startswith("EP")

    @staticmethod
    def encode(year,week,patid):
        """Return the location string for ``patid`` in ``week`` of ``year``.

        ``week`` may be a str or an int; it is zero-padded to two characters.
        """
        # str() lets callers pass an int week, as EPODatabase._week_dir allows.
        return "{}{}{}".format(year,str(week).zfill(2),patid)

    @staticmethod
    def isloc(s):
        """Return True if the string ``s`` has the shape of a location."""
        if len(s) <= 6:
            return False
        return Location._valid(*Location._split(s))

    @staticmethod
    def decode(s):
        """Split a location string into the tuple ``(year, week, patid)``.

        Raises AssertionError if ``s`` is not a well-formed location.
        """
        (year,week,patid) = Location._split(s)
        assert Location._valid(year,week,patid), "Not a location: {!r}".format(s)
        return (year,week,patid)


class EPODatabaseIndex(object):
    """This class creates and manages an index of the EPO Database.

    Currently it supports two kinds ("A1" and "B1").
    The index contains mappings between application number APN and publication number PUBN
    as well as `loc`, which is the unique identifier of a patent.
    The mappings are stored in the dictionary `index`, which contains the apn2pubn, pubn2apn,
    apn2lang, apn2loc, loc2apn and apn2hasdesc mappings.

    Note that len(apn2loc) != len(loc2apn). The reason is that the application number which
    includes the kind does NOT uniquely identify the document. The same appln_nr and kind may
    have different versions, identified by their publication date.
    This index abstracts from that: each APN entry is overwritten by the document that is
    indexed last, which is intended to be the latest publication.
    A more correct index would add the publication date to the ID (incurring the disadvantage
    that search mechanisms would have to be more complicated).
    """

    # Checkpoint file for generated features; relative to the working directory.
    FEATURES_FILE = "features.json.gz"

    def __init__(self,epo,dbpath):
        """Attach to database ``epo`` and load the index at ``dbpath`` if it exists.

        ``self.index`` is None when no index file exists yet.
        Raises IOError if the directory part of ``dbpath`` does not exist.
        """
        self.epo = epo
        self.dbpath = dbpath
        if os.path.dirname(dbpath)!= '' and not os.path.exists(os.path.dirname(dbpath)):
            raise IOError("Directory not found: '{}'".format(os.path.dirname(dbpath)))
        if os.path.exists(dbpath):
            with gzip.open(dbpath,"rt") as f:
                self.index = json.loads(f.read())
        else:
            self.index = None

    def _save_features(self,feature_map):
        """Write ``feature_map`` as gzip-compressed JSON to FEATURES_FILE."""
        with gzip.open(self.FEATURES_FILE,"wt") as f:
            f.write(json.dumps(feature_map))

    def build_index(self,feature_generator=None):
        """Build (or resume building) the index by reading every patent of ``self.epo``.

        An existing index file and an existing feature checkpoint are loaded
        first, and patents whose ``loc`` is already indexed are skipped, so an
        interrupted run can be continued. Patents whose reading raises
        KeyError are reported and skipped.

        If ``feature_generator`` is given, ``feature_generator.generate(patent)``
        is stored per APN. The index (and the features, if generated) are saved
        at every 50000th iteration and at the end.

        Returns the feature map (APN -> features).
        """
        self.index = {'apn2pubn':{},'pubn2apn':{},'apn2lang':{},'apn2loc':{},'loc2apn':{},'apn2hasdesc':{}}
        feature_map = {}

        # Resume from earlier results only if they exist; loading them
        # unconditionally made the very first build fail.
        if os.path.exists(self.dbpath):
            print("Loading index...")
            with gzip.open(self.dbpath,"rt") as f:
                self.index = json.loads(f.read())
        if os.path.exists(self.FEATURES_FILE):
            with gzip.open(self.FEATURES_FILE,"rt") as f:
                feature_map = json.loads(f.read())

        print("Building index...")
        i = 0
        for (year,week,kind,pat) in self.epo.iterator():
            # i counts visited patents, including the ones skipped below.
            i += 1
            loc = Location.encode(year,week,pat)
            if loc in self.index["loc2apn"]:
                continue
            apn = pat.replace("NW","")
            try:
                patent = self.epo.read_patent(year,week,pat)
            except KeyError as e:
                print("Could not read patent {}: {}".format(loc,e))
                continue

            self.index["apn2pubn"][apn] = patent["pub_nr"]
            self.index["pubn2apn"][patent["pub_nr"]] = apn
            self.index["apn2loc"][apn] = loc

            self.index["loc2apn"][loc] = apn
            self.index["apn2lang"][apn] = patent["lang"]
            self.index["apn2hasdesc"][apn] = "description" in patent

            if feature_generator:
                feature_map[apn] = feature_generator.generate(patent)

            if i % 2000 == 0:
                if i % 50000 == 0:
                    print("{}: {}/{}: processed {} patents (saved checkpoint)".format(datetime.datetime.now(),year,week,i))
                    self.save_index()
                    if feature_generator:
                        self._save_features(feature_map)
                else:
                    print("{}: {}/{}: processed {} patents".format(datetime.datetime.now(),year,week,i))

        self.save_index()
        if feature_generator:
            self._save_features(feature_map)

        return feature_map

    def save_index(self):
        """Write ``self.index`` as gzip-compressed JSON to ``self.dbpath``."""
        with gzip.open(self.dbpath,"wt") as f:
            f.write(json.dumps(self.index))
        


class EPODatabaseManager(object):
    """Read patents and walk the index of an ``EPODatabaseIndex``."""

    def __init__(self,epo_index):
        self.index = epo_index.index
        self.epo_index = epo_index

    def read_patent(self,appln_nr_or_loc):
        """Read a patent given either its location string or its APN.

        Anything that is not shaped like a location is looked up in the
        ``apn2loc`` mapping of the index.
        """
        if not Location.isloc(appln_nr_or_loc):
            loc = self.index["apn2loc"][appln_nr_or_loc]
        else:
            loc = appln_nr_or_loc
        year,week,patid = Location.decode(loc)
        return self.epo_index.epo.read_patent(year,week,patid)

    def iterator(self):
        """Yield ``(apn, loc)`` pairs of the index, ordered by ``loc``."""
        # sorted() is stable, so equal locs keep the order of the mapping.
        by_loc = sorted(self.index["apn2loc"].items(),key=lambda pair : pair[1])
        for pair in by_loc:
            yield pair
        

class ApplicationPatentManager(object):
    """Map each application (number without kind) to one patent document.

    A B1 document replaces an A1 document of the same application. Only
    documents that have a description are considered and, when ``only_deen``
    is true, only those whose language is "en".
    NOTE(review): the flag name suggests German and English, but only "en"
    is accepted by the code -- confirm the intent.
    """

    def __init__(self,edi,only_deen=True):
        self.edi = edi
        self.only_deen = only_deen
        self._build_appln_nr_index()

    def _build_appln_nr_index(self):
        """Fill ``self.appln2apn`` (application number -> chosen APN)."""
        chosen = self.appln2apn = {}
        for apn in self.edi.index["apn2pubn"]:
            assert apn.endswith(("A1","B1"))
            appln_nr,kind = apn[:-2],apn[-2:]
            # Never let an A1 replace an entry that is already present.
            if kind == "A1" and appln_nr in chosen:
                continue
            if not self.edi.index["apn2hasdesc"][apn]:
                continue
            if self.only_deen and self.edi.index["apn2lang"][apn] != "en":
                continue
            chosen[appln_nr] = apn

    def iterator(self):
        """Yield ``(apn, loc)`` for every chosen document, ordered by ``loc``."""
        for (apn,loc) in EPODatabaseManager(self.edi).iterator():
            # Skip documents that were not chosen for their application.
            if self.appln2apn.get(apn[:-2],None) == apn and apn[:-2] in self.appln2apn:
                yield (apn,loc)

def word(wrd):
    """Wrap the regex fragment ``wrd`` so that it only matches as a whole word.

    The fragment is enclosed in ``[^\\w]`` on both sides. Each side consumes
    one non-word character, so a match needs such a character before and
    after it; a word at the very start or end of the text does not match.
    """
    # Raw string: "\w" in a normal string literal is an invalid escape sequence.
    return r"[^\w]{}[^\w]".format(wrd)

def caseInsensitive(regexstr):
    """Compile ``regexstr`` into a pattern that ignores case."""
    pattern = re.compile(regexstr,flags=re.IGNORECASE)
    return pattern

def caseSensitive(regexstr):
    """Compile ``regexstr`` into a pattern that distinguishes case."""
    pattern = re.compile(regexstr,flags=0)
    return pattern

class ButNotSearcher(object):
    """Searcher that matches ``appear`` only in texts where ``notappear`` is absent.

    Both arguments must offer a ``search(text)`` method that returns None
    when nothing is found (e.g. compiled regular expressions).
    """

    def __init__(self,appear,notappear):
        self.appear,self.notappear = appear,notappear

    def search(self,text):
        """Return the match of ``appear`` in ``text``, or None.

        None is returned when ``appear`` does not match or when ``notappear``
        matches anywhere in ``text``.
        """
        hit = self.appear.search(text)
        # The exclusion pattern is only consulted when there is a hit.
        if hit is not None and self.notappear.search(text) is None:
            return hit
        return None
        

class WordGroupsInSameSentenceFinder(object):
    """Detect co-occurrence of two groups of search terms.

    ``list1`` and ``list2`` are iterables of search terms. A term is either a
    regex string, which is compiled here (ignoring case unless the matching
    ``caseSensitive1`` / ``caseSensitive2`` flag is set), or any object with a
    ``search(text)`` method returning None on failure, which is used as is.
    A group matches a text if at least one of its terms matches.
    """

    def __init__(self,list1,list2,stop_count=1,caseSensitive1=False,caseSensitive2=False):
        self.list1 = self._compile_all(list1,caseSensitive1)
        self.list2 = self._compile_all(list2,caseSensitive2)
        # in_same_sentence stops counting once this many sentences matched.
        self.stop_count = stop_count

    @staticmethod
    def _compile_all(terms,case_sensitive):
        """Compile the string terms; pass ready-made searchers through."""
        flags = 0 if case_sensitive else re.IGNORECASE
        return [re.compile(t,flags) if isinstance(t,str) else t for t in terms]

    @staticmethod
    def _any_match(searchers,text):
        """Return True if at least one searcher finds something in ``text``."""
        return any(s.search(text) is not None for s in searchers)

    def in_same_sentence(self,sents):
        """Count the sentences in ``sents`` in which both groups match.

        Counting stops early, returning ``stop_count``, as soon as that many
        sentences have been found.
        """
        count = 0
        for sent in sents:
            if self._any_match(self.list1,sent) and self._any_match(self.list2,sent):
                count += 1
                if count == self.stop_count:
                    return count
        return count

    def in_same_patent(self,text):
        """Return True if both groups match somewhere in ``text``."""
        return self._any_match(self.list1,text) and self._any_match(self.list2,text)

class AutomationKeywordsVersion6(object):
    """Sixth (#6) attempt to classify automation patents.

    It still looks at the occurrence of keywords in the description, but also
    considers how many times some of them appear and looks for keyword
    combinations on a sentence level.

    After construction ``self.labels`` holds the feature names (sorted as
    strings) and ``self.none`` a list of None of the same length, which is
    what ``generate`` returns for patents without a description.
    """
    def __init__(self):
        # All patterns are raw strings: "\w" and "\s" are invalid escape
        # sequences in normal string literals.
        self.auto = re.compile(word(r"(automation|automatization)"),re.IGNORECASE)
        self.automat = re.compile(word(r"automat\w*"),re.IGNORECASE)
        self.automatautonomous = WordGroupsInSameSentenceFinder(map(word,[r"automat\w*",r"autonomous"]),map(word,[r"machine",r"vehicle system",r"welding",r"knitting",r"weaving",r"convey\w*",r"storage",r"store",r"operator",r"handling",r"regulat\w*",r"manipulat\w*",r"arm",r"sensor",r"inspect\w*",r"warehouse",r"manufacturing",r"machining",r"equipment",r"apparatus"]),stop_count=2)
        self.labor = re.compile(word(r"(labou?rious|labou?r)"),re.IGNORECASE)
        # Not wrapped in word(): matches "robot" anywhere, also inside other words.
        self.robot = re.compile(r"robot\w*",re.IGNORECASE)
        self.surmed = re.compile(word(r"(surgical|medical)"),re.IGNORECASE)
        self.NC = WordGroupsInSameSentenceFinder([word(r"NC")],map(word,[r"machine",r"apparatus",r"equipment",r"manufacturing",r"machining"]),caseSensitive1=True)

        ContentAddressableMemory = re.compile(word(r"content addressable memory"),re.IGNORECASE)
        CAMbutnotContentAddressableMemory = ButNotSearcher(caseSensitive(word(r"CAM")),ContentAddressableMemory)

        self.CADCAM = WordGroupsInSameSentenceFinder([word(r"CAD"),CAMbutnotContentAddressableMemory],map(word,[r"machine",r"apparatus",r"equipment",r"manufacturing",r"machining"]),caseSensitive1=True)
        self.numconalone = re.compile(word(r"(numerically[-\s]controlled|numeric[-\s]control|numerical[-\s]control)"),re.IGNORECASE)
        self.CNCalone = re.compile(word(r"CNC"))
        # NOTE(review): the alternatives are not grouped, so the word()
        # delimiters bind only to the first and the last alternative.
        # Left unchanged because grouping would alter classification results.
        self.computeraided = WordGroupsInSameSentenceFinder([word(r"computer[-\s]?aided|computer[-\s]?assisted|computer[-\s]?supported")],map(word,[r"machine",r"apparatus",r"equipment",r"manufacturing",r"machining"]))
        self.threedee = re.compile(word(r"(3d print\w*|additive manufacturing|additive layer manufacturing)"),re.IGNORECASE)
        self.flexman = re.compile(word(r"flexible manufacturing"),re.IGNORECASE)

        self.PLC_acro_alone = re.compile(word(r"PLC"))
        self.PLC_name_alone = re.compile(word(r"Programmable Logic Controller"),re.IGNORECASE)
        self.powerline = re.compile(word(r"power[\s-]?line"),re.IGNORECASE)

        # this initializes self.none and self.labels
        self.generate({"description":""})

    def generate(self,patent):
        """Return the list of 0/1 features for ``patent``.

        The values are ordered like ``self.labels``. If ``patent`` has no
        "description" key, ``self.none`` is returned instead.

        WARNING: the features are collected from this method's local
        variables via ``locals()``. Every local whose name is not listed in
        ``_make_ret``'s IGNORE list and does not start with "_" becomes a
        feature, so adding or renaming locals changes the output.
        """
        if "description" in patent:
            text = patent["description"]
            sents = nltk.sent_tokenize(text)

            # 1. Automat* patents
            automat = int(self.auto.search(text) is not None or len(self.automat.findall(text))>=5 or self.automatautonomous.in_same_sentence(sents)>=2)

            # 2. Labor patents
            labor = int(self.labor.search(text) is not None)

            # 3. Robot patents
            robot = int(self.robot.search(text) is not None and self.surmed.search(text) is None)

            # 4. CNC patents
            _CNC = self.numconalone.search(text) is not None or self.CNCalone.search(text) is not None
            _NCs = self.NC.in_same_sentence(sents)>=1
            CNC = int(_CNC or _NCs)

            # 5. CAD/CAM patents
            _computeraidedp = self.computeraided.in_same_patent(text)
            _CADCAM = self.CADCAM.in_same_sentence(sents)>=1
            CADCAM = int(_computeraidedp or _CADCAM)

            # 6. 3D printing patents
            threedee = int(self.threedee.search(text) is not None)
            
            # 7. Flexible Manufacturing
            flexman = int(self.flexman.search(text) is not None)

            # 8. PLC patents
            PLC = int(self.PLC_name_alone.search(text) is not None or (self.PLC_acro_alone.search(text) is not None and self.powerline.search(text) is None))

            # Union
            anyclassification = int(automat>0 or labor>0 or robot>0 or CNC>0 or CADCAM>0 or threedee>0 or flexman>0 or PLC>0)

            # This takes care of returning all the local variables defined above as a list
            # (It also takes care of initializing self.labels and self.none)
            v = self._make_ret(locals())
            if "labels" not in self.__dict__:
                self.labels = v["keys"]
                self.none = len(self.labels)*[None]
            return v["vals"]
        else:
            return self.none

    def _make_ret(self,loc):
        """Turn a ``locals()`` snapshot into ``{'keys': [...], 'vals': [...]}``.

        Names in IGNORE and names starting with "_" are dropped; the rest is
        sorted by name, with ``vals`` in the same order as ``keys``.
        """
        IGNORE = ["patent","self","text","sents"]
        keys = list(loc.keys())
        vals = list(loc.values())
        idx = sorted(range(len(keys)),key=lambda i : keys[i])
        vals = [vals[i] for i in idx if keys[i] not in IGNORE and not keys[i].startswith("_")]
        keys = [keys[i] for i in idx if keys[i] not in IGNORE and not keys[i].startswith("_")]
        return {'keys':keys,'vals':vals}
    
if __name__ == "__main__":
    # Command-line entry point.
    #
    # Usage:  <script> --task TASK [--suffix SUFFIX] [--class_version N] [--epo_base DIR]
    # The positional form "<script> TASK [SUFFIX]" is accepted as well.
    #
    # Tasks:
    #   console         interactive Python console for inspecting single patents
    #   build_index     build/resume index.json.gz (and features.json.gz)
    #   build_features  build/resume featuresV<version>.json.gz
    #   save_csv        write <PATH>/appln_features.csv from the features file
    #   restrict_tf     split appln_features.csv into application-year samples
    #   merge_csv       merge the features with the patstat IPC tables
    TASKS = ["console","build_index","build_features","save_csv","restrict_tf","merge_csv"]

    parser = argparse.ArgumentParser(description='Process EPO database.')
    parser.add_argument('--class_version',type=int,help='Version of the classification')
    parser.add_argument('--epo_base',type=str,help='Base directory of the EPO database')
    parser.add_argument('--task',type=str,help='Task to perform')
    parser.add_argument('--suffix',type=str,help='Sample suffix used by merge_csv, e.g. 1997-2011')
    parser.add_argument('positional',nargs='*',help='Positional form: [TASK] [SUFFIX]')
    args = parser.parse_args()

    if args.class_version is not None:
        CLASS_VERSION = args.class_version
    if args.epo_base:
        EPO_BASE = args.epo_base
    # Keep the module-level values derived from the configuration in sync
    # with the command-line overrides.
    base = EPO_BASE
    PATH = "datasets/V{}".format(CLASS_VERSION)

    # Take the task from --task (reading sys.argv[1] directly clashed with
    # argparse); fall back to the positional arguments.
    positional = list(args.positional)
    task = args.task if args.task else (positional.pop(0) if positional else None)
    suffix = args.suffix if args.suffix else (positional.pop(0) if positional else None)
    if task not in TASKS:
        parser.error("unknown or missing task {!r}; choose one of: {}".format(task,", ".join(TASKS)))

    if task == "console":
        print("Loading database...")
        db = EPODatabase(EPO_BASE)
        edi = EPODatabaseIndex(db,"index.json.gz")

        dbm = EPODatabaseManager(edi)
        ac = AutomationKeywordsVersion6()

        print("Creating application index...")
        apm = ApplicationPatentManager(edi)
        print("Discovered {} applications".format(len(apm.appln2apn)))

        print("==== CONSOLE ====")
        print("- Type select(\"<apn>\") to select an APN")
        print("- Type show() to show its descrption")
        print("- Type gen() to generate features for patent")
        print("[Note: this is a python console. You can do everything else as you please!]")

        # This code runs at module level, so the helpers below share the
        # currently selected patent through module globals.
        apn = None
        patent = None

        def select(apn_):
            """Load the patent with the given APN (or location) as the current one."""
            global apn
            global patent
            apn = apn_.replace(" ","")
            try:
                patent = dbm.read_patent(apn)
            except Exception as e:
                print("ERROR: could not read patent")
                print(e)

        def show(apn_=None):
            """Print the description of the given (or the current) patent."""
            if apn_:
                select(apn_)
            if patent and "description" in patent:
                print(patent["description"])
            elif patent:
                print("Patent has no description")
            else:
                print("No patent loaded")

        def gen(apn_=None):
            """Print the generated features of the given (or the current) patent."""
            if apn_:
                select(apn_)
            if not patent:
                print("No patent loaded")
                return

            print(pd.DataFrame({'f':ac.generate(patent),'l':ac.labels}))

        code.interact(local=locals())

    elif task == "build_index":
        db = EPODatabase(EPO_BASE)
        edi = EPODatabaseIndex(db,"index.json.gz")
        # The name AutomationCounter used here before is not defined in this
        # file; AutomationKeywordsVersion6 is the only feature generator it has.
        # NOTE(review): confirm this is the intended generator.
        fmap = edi.build_index(feature_generator=AutomationKeywordsVersion6())
        with gzip.open("features.json.gz","wt") as f:
            f.write(json.dumps(fmap))

    elif task == "build_features":
        print("Loading database...")
        db = EPODatabase(EPO_BASE)
        edi = EPODatabaseIndex(db,"index.json.gz")

        dbm = EPODatabaseManager(edi)
        ac = AutomationKeywordsVersion6()
        # The "{}" placeholder was missing, so the version never made it into
        # the name although save_csv reads a versioned file.
        ffn = "featuresV{}.json.gz".format(CLASS_VERSION)

        print("[Using feature generator {}]".format(type(ac).__name__))

        if os.path.exists(ffn):
            print("Loading existing features in {}...".format(ffn))
            with gzip.open(ffn,"rt") as f:
                feature_map = json.loads(f.read())
        else:
            print("Opening new features database {}".format(ffn))
            feature_map = {}

        print("Creating application index...")
        apm = ApplicationPatentManager(edi)
        print("Discovered {} applications".format(len(apm.appln2apn)))

        i = 0

        def sigint_handler(signum,frame):
            """Save the features collected so far, then exit."""
            print("Caught SIGINT.. saving features and exiting...")
            with gzip.open(ffn,"wt") as f:
                f.write(json.dumps(feature_map))
            print("Saved features after processing {} patents".format(i))
            sys.exit(1)
        signal.signal(signal.SIGINT,sigint_handler)

        print("Building features...")
        print("{}: Starting...".format(datetime.datetime.now()))
        for (apn,loc) in apm.iterator():
            i += 1
            # Already present from an earlier (interrupted) run.
            if apn in feature_map:
                continue
            pat = dbm.read_patent(apn)
            feature_map[apn] = ac.generate(pat)

            if i % 5000 == 0:
                if i % 50000 == 0:
                    print("{}: Processed {} patents (checkpoint)".format(datetime.datetime.now(),i))
                    with gzip.open(ffn,"wt") as f:
                        f.write(json.dumps(feature_map))
                else:
                    print("{}: Processed {} patents".format(datetime.datetime.now(),i))

        with gzip.open(ffn,"wt") as f:
            f.write(json.dumps(feature_map))

    elif task == "save_csv":
        print("Loading index...")
        db = EPODatabase(EPO_BASE)
        edi = EPODatabaseIndex(db,"index.json.gz")
        print("Creating application index...")
        apm = ApplicationPatentManager(edi)
        print("Discovered {} applications".format(len(apm.appln2apn)))
        print("Loading features...")
        ffn = "featuresV{}.json.gz".format(CLASS_VERSION)
        with gzip.open(ffn,"rt") as f:
            feature_map = json.loads(f.read())

        ac = AutomationKeywordsVersion6()
        columns = ["appln_nr"] + list(ac.labels)

        print("Iterating...")
        data = []
        for (apn,loc) in apm.iterator():
            x = feature_map[apn]
            if x[0] is None:
                # Pad with as many values as there are labels; a fixed 12
                # did not match the number of columns.
                x = len(ac.labels)*[None]
            # apn[2:-2] drops the first two and the last two characters
            # (the "EP" prefix and the kind).
            data.append([apn[2:-2]] + list(x))

        print("Created features for {} applications.".format(len(data)))
        df = pd.DataFrame(data,columns=columns)
        print("Saving to csv...")
        os.makedirs(PATH,exist_ok=True)
        df.to_csv('{}/appln_features.csv'.format(PATH),index=False)
        print("All done.")

    elif task == "restrict_tf":
        print("Loading features...")
        f = pd.read_csv("{}/appln_features.csv".format(PATH),dtype={'appln_nr':str})
        print("Loading...")
        year = pd.read_csv("patstat/appln_year.csv",dtype={'appln_nr':str})
        print("Merging with appln_year data...")
        m = f.merge(year,on="appln_nr",how="left")
        # Restrict to samples:
        samples = [
            ("1997-2011",(m.appln_year >= 1997) & (m.appln_year <= 2011)),
            ("from-1998",m.appln_year >= 1998),
            ("until-1997",m.appln_year <= 1997),
        ]
        print("Writing features csv files...")
        for (sample,mask) in samples:
            out_dir = "{}-{}".format(PATH,sample)
            os.makedirs(out_dir,exist_ok=True)
            # drop() returns a new frame instead of deleting from a slice of m.
            m.loc[mask].drop(columns=["appln_year"]).to_csv('{}/appln_features.csv'.format(out_dir),index=False)

    elif task == "merge_csv":
        path = "{}-{}".format(PATH,suffix) if suffix else PATH

        print("Loading...")
        ipc4 = pd.read_csv("patstat/appln_ipc4.csv",dtype={'appln_nr':str})
        ipc6 = pd.read_csv("patstat/appln_ipc6.csv",dtype={'appln_nr':str})

        # Drop the column from each table independently where present; the
        # former bare "except: pass" skipped ipc6 whenever ipc4 lacked it.
        ipc4 = ipc4.drop(columns=["isauto"],errors="ignore")
        ipc6 = ipc6.drop(columns=["isauto"],errors="ignore")

        print("Loading features...")
        f = pd.read_csv("{}/appln_features.csv".format(path),dtype={'appln_nr':str})

        print("Merging...")
        ipc6new = ipc6.merge(f,on="appln_nr",how="inner")
        ipc4new = ipc4.merge(f,on="appln_nr",how="inner")

        ipc6new.to_csv("{}/appln_ipc6.csv".format(path),index=False)
        ipc4new.to_csv("{}/appln_ipc4.csv".format(path),index=False)

        
